import numpy as np
import random
import time
import matplotlib.pyplot as plt
import pickle

# ---- Grid world -----------------------------------------------------------
# Cells are (row, col) pairs; rows lie in [0, GRID_HEIGHT), cols in [0, GRID_WIDTH).
GRID_HEIGHT = 8
GRID_WIDTH = 10
START = (7, 0)
GOAL = (7, 9)                                   # terminal cell
OBSTACLES = [(5, 2), (5, 5), (2, 4), (6, 7)]    # terminal cells

# Moves as (row offset, col offset), indexed by action id.
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]
NUM_ACTIONS = len(ACTIONS)

# Return thresholds (eta levels) of the CVaR tables: -150, -149, ..., 100.
# Index updates of the form `idx - round(reward)` rely on this unit spacing.
H = np.linspace(-150.0, 100.0, num=251)

# ---- Rewards and dynamics -------------------------------------------------
reward_default = -1          # every non-terminal step, including bumping into a wall
reward_goal = 50
reward_obstacle = -50
random_probability = 0.3     # chance the chosen move is replaced by one of the other moves

# ---- Learning settings ----------------------------------------------------
gamma = 0.99
threshold = 1e-4
num_episodes = 15000         # training episodes per PCVaR run
q = 0.1                      # CVaR tail level (lower q-quantile of returns)
SEED = 0

def set_seed(seed):
    """Seed both random sources used in this file: `random` and NumPy's global RNG."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    
def is_valid(pos):
    """Return True when the (row, col) pair `pos` lies inside the grid."""
    row, col = pos
    row_ok = 0 <= row < GRID_HEIGHT
    col_ok = 0 <= col < GRID_WIDTH
    return row_ok and col_ok

def step(state, action_index):
    """Advance the slippery grid by one move.

    With probability `random_probability` the requested action is swapped for
    one of the other actions, chosen uniformly. Moving off the grid leaves the
    agent in place. Entering an obstacle or the goal ends the episode with a
    reward drawn from N(reward_obstacle, 1) or N(reward_goal, 1), respectively.
    Every other move costs `reward_default`.

    Returns:
        (next_state, reward, done)
    """
    if random.random() < random_probability:
        choices = list(range(NUM_ACTIONS))
        choices.remove(action_index)
        action_index = random.choice(choices)

    d_row, d_col = ACTIONS[action_index]
    target = (state[0] + d_row, state[1] + d_col)

    if not is_valid(target):
        return state, reward_default, False
    if target in OBSTACLES:
        return target, np.random.normal(reward_obstacle, 1), True
    if target == GOAL:
        return target, np.random.normal(reward_goal, 1), True
    return target, reward_default, False

def choose_action(Q, state, epsilon):
    """Epsilon-greedy action w.r.t. the table `Q` indexed as Q[row, col, action].

    The greedy branch uses `np.argmax`, so ties go to the lowest action index.
    """
    if not np.random.rand() < epsilon:
        return np.argmax(Q[state[0], state[1]])
    return random.randint(0, NUM_ACTIONS - 1)

def choose_action_PCVaR(Q_cvar, M, state, idx, epsilon):
    """Epsilon-greedy action for the CVaR tables at threshold index `idx`.

    With probability `epsilon`, return a uniformly random action. Otherwise
    return an action that maximizes ``Q_cvar - H[idx] * M`` at (state, idx),
    breaking ties uniformly at random.
    """
    if np.random.rand() < epsilon:
        return random.randint(0, NUM_ACTIONS - 1)
    row, col = state[0], state[1]
    scores = Q_cvar[row, col, idx, :] - H[idx] * M[row, col, idx, :]
    candidates = np.flatnonzero(scores == scores.max())
    return np.random.choice(candidates)

def step_deterministic(state, action):
    """Noise-free version of `step` with fixed terminal rewards.

    Moving off the grid leaves the agent in place. The resulting cell then
    decides the outcome: an obstacle or the goal ends the episode with
    `reward_obstacle` / `reward_goal`; anything else costs `reward_default`.

    Returns:
        (next_state, reward, done)
    """
    d_row, d_col = ACTIONS[action]
    landed = (state[0] + d_row, state[1] + d_col)
    if not is_valid(landed):
        landed = state

    if landed in OBSTACLES:
        reward, done = reward_obstacle, True
    elif landed == GOAL:
        reward, done = reward_goal, True
    else:
        reward, done = reward_default, False
    return landed, reward, done

def sarsa(episodes=100000, alpha=0.01, epsilon_min=0.0, max_steps=500, decay_episodes=25000):
    """Tabular on-policy SARSA with undiscounted targets on the slippery grid.

    Q starts at an optimistic 10 everywhere. Epsilon decays linearly from 1.0
    over the first `decay_episodes` episodes and never drops below
    `epsilon_min`; the default of 0.0 gives a plain linear decay to greedy.
    Each episode starts at START and lasts at most `max_steps` steps.

    Args:
        episodes: number of training episodes.
        alpha: learning rate.
        epsilon_min: lower bound on the exploration rate.
        max_steps: per-episode step limit.
        decay_episodes: length of the linear epsilon decay.

    Returns:
        (Q, rewards): Q of shape (GRID_HEIGHT, GRID_WIDTH, NUM_ACTIONS), and a
        list with one total return per episode.
    """
    Q = np.zeros((GRID_HEIGHT, GRID_WIDTH, NUM_ACTIONS)) + 10
    rewards = []
    for episode in range(episodes):
        if episode < decay_episodes:
            epsilon = max(1.0 - episode / decay_episodes, epsilon_min)
        else:
            epsilon = epsilon_min

        state = START
        action = choose_action(Q, state, epsilon)
        total_reward = 0

        for _ in range(max_steps):
            next_state, reward, done = step(state, action)
            # The next action is also drawn on terminal transitions. This keeps
            # the RNG stream unchanged; the value is simply unused in that case.
            next_action = choose_action(Q, next_state, epsilon)

            x, y = state
            nx, ny = next_state
            # Terminal cells have no future return. Their Q entries keep the
            # optimistic initial 10 forever, so they must not be bootstrapped from.
            target = reward if done else reward + Q[nx, ny, next_action]
            Q[x, y, action] += alpha * (target - Q[x, y, action])

            total_reward += reward
            state = next_state
            action = next_action
            if done:
                break

        rewards.append(total_reward)

    return Q, rewards

def Pre_train(Q, num_simul=10000):
    """Monte-Carlo warm start for the CVaR tables consumed by `update_PCVaR`.

    Simulates `num_simul` episodes of the slippery grid. Each episode:
    - starts in a cell drawn uniformly from the non-terminal cells;
    - takes its first action epsilon-greedily w.r.t. `Q` (epsilon = 0.6), then
      acts greedily;
    - is cut after 1001 steps.

    Every visited (state, action) is credited at each threshold index
    ``idx = i - round(reward so far)`` that lies inside H, for i ranging over
    H's indices. The undiscounted return-to-go ``G`` is averaged into

        M[x, y, idx, a]      = mean of 1{G <= H[idx]}
        Q_cvar[x, y, idx, a] = mean of G * 1{G <= H[idx]}

    Terminal cells (GOAL and OBSTACLES) receive the boundary values of a zero
    return-to-go: Q_cvar = 0, and M = 1 wherever H >= 0.

    Args:
        Q: action-value table indexed as Q[row, col, action] (e.g. from `sarsa`).
        num_simul: number of simulated episodes.

    Returns:
        (Q_cvar, M, rewards): the averaged tables, each of shape
        (GRID_HEIGHT, GRID_WIDTH, len(H), NUM_ACTIONS) with never-visited
        entries left at 0, and the list of per-episode returns.
    """
    n_levels = len(H)
    shape = (GRID_HEIGHT, GRID_WIDTH, n_levels, NUM_ACTIONS)
    Q_cvar_sum = np.zeros(shape)
    M_sum = np.zeros(shape)
    Count = np.zeros(shape)
    rewards = []

    for _ in range(num_simul):
        # Rejection-sample a non-terminal start cell.
        while True:
            state = (random.randint(0, GRID_HEIGHT - 1), random.randint(0, GRID_WIDTH - 1))
            if state not in OBSTACLES and state != GOAL:
                break

        total_reward = 0
        done = False
        trajectory = []  # entries: (reward so far, state, action, reward)
        time_step = 0
        action = choose_action(Q, state, 0.6)  # exploratory first action
        while not done:
            next_state, reward, done = step(state, action)
            trajectory.append((total_reward, state, action, reward))
            state = next_state
            total_reward += reward
            time_step += 1
            if time_step > 1000:
                break
            # Also called after a terminal step (the loop then exits); kept so
            # the RNG stream matches the established experiment.
            action = choose_action(Q, state, 0.0)
        rewards.append(total_reward)

        # Undiscounted returns-to-go, accumulated back to front in O(T).
        returns_to_go = []
        G = 0
        for *_, r in reversed(trajectory):
            G = r + G
            returns_to_go.append(G)
        returns_to_go.reverse()

        for (sum_r, (s_x, s_y), a, _), Gt in zip(trajectory, returns_to_go):
            # idx = i - shift for i in [0, n_levels); keep only idx inside H.
            # Each index is credited once; out-of-range indices are skipped
            # rather than clipped onto the edge level.
            shift = int(round(sum_r))
            lo, hi = max(0, -shift), min(n_levels, n_levels - shift)
            if lo >= hi:
                continue
            indicator = (Gt <= H[lo:hi]).astype(float)
            Count[s_x, s_y, lo:hi, a] += 1
            M_sum[s_x, s_y, lo:hi, a] += indicator
            Q_cvar_sum[s_x, s_y, lo:hi, a] += Gt * indicator

    M = np.zeros_like(M_sum)
    Q_cvar = np.zeros_like(Q_cvar_sum)
    valid = Count > 0
    M[valid] = M_sum[valid] / Count[valid]
    Q_cvar[valid] = Q_cvar_sum[valid] / Count[valid]

    # Boundary condition at terminal cells: the return-to-go is 0, so
    # E[G 1{G <= h}] = 0 and P(G <= h) = 1 exactly when h >= 0 (same
    # non-strict indicator as above).
    non_negative = H >= 0
    for t_x, t_y in [GOAL, *OBSTACLES]:
        Q_cvar[t_x, t_y] = 0
        M[t_x, t_y, non_negative] = 1

    return Q_cvar, M, rewards

def get_policy_path_PCVaR(Q_cvar, M, eta_index):
    """Trace the greedy CVaR policy from START under noise-free moves.

    At each step:
    - take an action maximizing ``Q_cvar - H * M`` at the current threshold
      index, breaking ties uniformly at random;
    - apply it with `step_deterministic`;
    - shift the index by ``-round(reward)`` and clip it to H.

    Stops on entering GOAL or an obstacle, or after 300 moves.

    Returns:
        The list of visited cells, starting with START.
    """
    here = START
    cells = [here]
    last = len(H) - 1
    for _ in range(300):
        scores = Q_cvar[here[0], here[1], eta_index, :] - H[eta_index] * M[here[0], here[1], eta_index, :]
        move = np.random.choice(np.flatnonzero(scores == np.max(scores)))
        there, reward, _ = step_deterministic(here, move)
        cells.append(there)
        if there == GOAL or there in OBSTACLES:
            break
        here = there
        eta_index = np.clip(eta_index - int(round(reward)), 0, last)
    return cells

def draw_cvar_path(V, path, obstacles=OBSTACLES, goal=GOAL, start=START):
    """Show `V` as a grid heat map with obstacles, start/goal markers and `path` arrows.

    Obstacles are drawn as black squares, START as a green dot and GOAL as a
    red dot. Each consecutive pair of (row, col) cells in `path` becomes a
    yellow arrow. The figure is displayed with `plt.show()`.
    """
    _, ax = plt.subplots(figsize=(14, 5))
    image = ax.imshow(V, cmap='viridis', origin='upper')
    plt.colorbar(image, ax=ax, fraction=0.02, pad=0.04, label='Visit freq')

    # Major ticks at cell centres; minor ticks on cell borders carry the grid lines.
    ax.set_xticks(np.arange(GRID_WIDTH))
    ax.set_yticks(np.arange(GRID_HEIGHT))
    ax.set_xticks(np.arange(-0.5, GRID_WIDTH, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, GRID_HEIGHT, 1), minor=True)
    ax.grid(which='minor', color='black', linewidth=1.5)
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    for row, col in obstacles:
        ax.add_patch(plt.Rectangle((col - 0.5, row - 0.5), 1, 1, color='black'))

    for (row, col), colour in ((start, 'green'), (goal, 'red')):
        ax.add_patch(plt.Circle((col, row), 0.3, color=colour))

    offset, length = 0.3, 0.6
    for (row, col), (row2, col2) in zip(path, path[1:]):
        d_row, d_col = row2 - row, col2 - col
        ax.arrow(
            col - d_col * offset, row - d_row * offset, d_col * length, d_row * length,
            head_width=0.3, head_length=0.3,
            fc='yellow', ec='yellow', linewidth=4,
            length_includes_head=True,
            alpha=0.9,
        )
    plt.show()

def update_PCVaR(Q_cvar, M, H, lr1, lr2, trajectory):
    """Apply one averaged CVaR-recursion step to `Q_cvar` and `M`, in place.

    Each trajectory entry is ``(reward so far, state, action, reward, next_state)``.
    For an entry, the threshold indices visited are ``idx = i - round(reward so far)``
    for i = 0, 1, ..., stopping at the first idx outside H. For each such idx:
    - the next index is ``clip(idx - round(reward), 0, len(H) - 1)``;
    - a greedy next action ``a'`` is drawn uniformly among the maximizers of
      ``Q_cvar - H * M`` at the next state;
    - the targets are ``Q_cvar[next, next_idx, a'] + M[next, next_idx, a'] * reward``
      and ``M[next, next_idx, a']``.

    Targets are averaged per (row, col, idx, action). Every touched entry then
    moves toward its average with step size `lr1` (Q_cvar) or `lr2` (M).

    Returns:
        The same `Q_cvar` and `M` arrays, after the update.
    """
    n_levels = len(H)
    top = n_levels - 1
    q_target = np.zeros_like(Q_cvar)
    m_target = np.zeros_like(M)
    hits = np.zeros_like(Q_cvar)

    for budget_spent, s, a, r, s_next in trajectory:
        sx, sy = s[0], s[1]
        nx, ny = s_next[0], s_next[1]
        shift = int(round(budget_spent))
        r_shift = int(round(r))
        # A positive shift puts i = 0 below index 0, so nothing is visited.
        levels = range(-shift, n_levels) if shift <= 0 else range(0)
        for idx in levels:
            hits[sx, sy, idx, a] += 1
            nxt = np.clip(idx - r_shift, 0, top)
            scores = Q_cvar[nx, ny, nxt, :] - H[nxt] * M[nx, ny, nxt, :]
            a_next = np.random.choice(np.flatnonzero(scores == np.max(scores)))
            q_target[sx, sy, idx, a] += Q_cvar[nx, ny, nxt, a_next] + M[nx, ny, nxt, a_next] * r
            m_target[sx, sy, idx, a] += M[nx, ny, nxt, a_next]

    touched = hits > 0
    Q_cvar[touched] += lr1 * (q_target[touched] / hits[touched] - Q_cvar[touched])
    M[touched] += lr2 * (m_target[touched] / hits[touched] - M[touched])
    return Q_cvar, M

# Rollouts are cut once their step counter exceeds this value (after 1001 steps).
_PCVAR_MAX_STEPS = 1000
# Every _PCVAR_EVAL_EVERY episodes the greedy policy is evaluated on
# _PCVAR_EVAL_EPISODES rollouts.
_PCVAR_EVAL_EVERY = 1000
_PCVAR_EVAL_EPISODES = 10000
# Training episodes (1-based) after which the visit map and greedy path are drawn.
_PCVAR_PLOT_EPISODES = (2000, 12000)


def _pcvar_eta_index(eta):
    """Map threshold `eta` to its index in H, truncating toward zero and clipping to H.

    Assumes H has unit spacing, as the ``idx - round(reward)`` budget updates
    do. Clipping prevents an eta below H[0] from producing a negative index
    that would silently wrap to the top of H.
    """
    return int(np.clip(int(eta) - int(H[0]), 0, len(H) - 1))


def _pcvar_rollout(Q_cvar, M, eta_idx, epsilon, visit_count=None, trajectory=None):
    """Run one episode from START with `choose_action_PCVaR`; return its total reward.

    After every step, the threshold index is shifted by ``-round(reward)`` and
    clipped to H. When provided:
    - `visit_count` is incremented for each state an action is taken in;
    - `trajectory` receives ``[reward so far, state, action, reward, next_state]``
      entries, the format `update_PCVaR` consumes.
    """
    state = START
    total_reward = 0
    done = False
    t = 0
    while not done:
        if visit_count is not None:
            visit_count[state[0], state[1]] += 1
        action = choose_action_PCVaR(Q_cvar, M, state, eta_idx, epsilon)
        next_state, reward, done = step(state, action)
        if trajectory is not None:
            trajectory.append([total_reward, state, action, reward, next_state])
        total_reward += reward
        eta_idx = np.clip(eta_idx - int(round(reward)), 0, len(H) - 1)
        state = next_state
        t += 1
        if t > _PCVAR_MAX_STEPS:
            done = True
    return total_reward


def _pcvar_estimate_eta(Q_cvar, M):
    """Return ``(H[i], i)`` maximizing ``max_a h * (q - M) + Q_cvar`` at START.

    If M and Q_cvar estimate P(G <= h) and E[G 1{G <= h}], this objective is
    q times the variational CVaR bound at level q. On ties the lowest index wins.
    """
    Q_start = Q_cvar[START[0], START[1]]
    M_start = M[START[0], START[1]]
    objective = np.max(H[:, None] * (q - M_start) + Q_start, axis=1)
    best = int(np.argmax(objective))
    return H[best], best


def _pcvar_evaluate(Q_cvar, M, eta):
    """Run greedy (epsilon = 0) rollouts from START at threshold `eta`.

    Returns:
        (VaR, CVaR) of the sampled returns at tail level q.
    """
    start_idx = _pcvar_eta_index(eta)
    returns = np.array(
        [_pcvar_rollout(Q_cvar, M, start_idx, 0.0) for _ in range(_PCVAR_EVAL_EPISODES)]
    )
    var_test = np.percentile(returns, q * 100)
    cvar_test = np.mean(returns[returns <= var_test])
    return var_test, cvar_test


def _pcvar_train(Q_cvar, M, decay_episodes, eta_init, sample_eta):
    """Training loop shared by `PCVaR_Q_learning_eta` and `PCVaR_Q_learning_etasample`.

    With `sample_eta`, each episode's starting threshold is drawn from
    N(eta, sigma), truncated to +-2 sigma. sigma is 45 until the first
    checkpoint; from each checkpoint on it is max(45 * (3 - episodes_done // 2000), 0).
    That gives 135 at 1000 episodes, then 90, 45, and finally 0 from 6000 episodes.
    NOTE(review): the jump from 45 to 135 at the first checkpoint looks
    unintended; confirm.
    Without `sample_eta`, every episode starts at the latest estimated threshold.
    """
    alpha_theta = 0.01  # step size for Q_cvar
    alpha_phi = 0.01    # step size for M
    eta = eta_RN if eta_init is None else eta_init
    eta_index = _pcvar_eta_index(eta)
    sigma = 45
    visit_count = np.zeros((GRID_HEIGHT, GRID_WIDTH))

    for episode in range(num_episodes):
        epsilon_t = max(1.0 - episode / decay_episodes, 0.0)
        trajectory = []
        _pcvar_rollout(Q_cvar, M, eta_index, epsilon_t, visit_count, trajectory)
        Q_cvar, M = update_PCVaR(Q_cvar, M, H, alpha_theta, alpha_phi, trajectory)

        checkpoint = (episode + 1) % _PCVAR_EVAL_EVERY == 0
        if checkpoint:
            eta, best_index = _pcvar_estimate_eta(Q_cvar, M)
            if sample_eta:
                sigma = max(45 * (3 - (episode + 1) // 2000), 0)
            else:
                eta_index = best_index

        # Sampling happens before evaluation so the RNG stream matches the
        # established experiment.
        if sample_eta:
            sampled = np.random.normal(loc=eta, scale=sigma)
            sampled = np.clip(sampled, eta - 2 * sigma, eta + 2 * sigma)
            eta_index = int(np.clip(int(round(sampled)) - int(H[0]), 0, len(H) - 1))

        if checkpoint:
            var_test, cvar_test = _pcvar_evaluate(Q_cvar, M, eta)
            print(f"episode {episode + 1}: eta={eta:.1f}  VaR={var_test:.2f}  CVaR={cvar_test:.2f}")

        if episode + 1 in _PCVAR_PLOT_EPISODES:
            path = get_policy_path_PCVaR(Q_cvar, M, eta_index)
            draw_cvar_path(visit_count, path)

    return Q_cvar, M


def PCVaR_Q_learning_eta(Q_cvar, M, decay_episodes=2000, eta_init=None):
    """Train the CVaR tables, starting each episode at the current eta estimate.

    Runs `num_episodes` episodes from START, each followed by one
    `update_PCVaR` step. Epsilon decays linearly from 1 to 0 over
    `decay_episodes`. The starting threshold is `eta_init` (default: the
    global `eta_RN`) until the first re-estimate.

    Every 1000 episodes:
    - the threshold is re-estimated from the tables at START;
    - the greedy policy is evaluated on 10000 rollouts;
    - the resulting VaR/CVaR are printed.

    After episodes 2000 and 12000, the visit counts and greedy path are drawn.

    Returns:
        (Q_cvar, M), the same arrays updated in place.
    """
    return _pcvar_train(Q_cvar, M, decay_episodes, eta_init, sample_eta=False)


def PCVaR_Q_learning_etasample(Q_cvar, M, decay_episodes=2000, eta_init=None):
    """Like `PCVaR_Q_learning_eta`, but samples each episode's starting threshold.

    The threshold is drawn around the current eta estimate with a Gaussian
    whose width shrinks as training progresses; see `_pcvar_train`.

    Returns:
        (Q_cvar, M), the same arrays updated in place.
    """
    return _pcvar_train(Q_cvar, M, decay_episodes, eta_init, sample_eta=True)


# 1) Risk-neutral SARSA; its Q drives the behaviour policy of the Monte-Carlo warm start.
set_seed(SEED)
Q_sarsa, rewards_sarsa = sarsa()

# 2) Monte-Carlo warm start of the CVaR tables. The q-quantile of its episode
#    returns is the initial eta, which the PCVaR_* trainers read as the global `eta_RN`.
set_seed(SEED)
Pre_Q_cvar, Pre_M, Pre_rewards = Pre_train(Q_sarsa, 50000)
eta_RN = np.quantile(Pre_rewards, q)

# 3) Non-pretrained initial tables. Terminal cells still need the boundary
#    condition of a zero return-to-go: M = P(0 <= h) = 1 for h >= 0. With M == 0
#    everywhere, update_PCVaR's targets (Q_cvar' + M' * r and M') are identically
#    0, and the "pretrained parameters X" runs would never leave zero.
Q_cvar_zeros = np.zeros((GRID_HEIGHT, GRID_WIDTH, len(H), NUM_ACTIONS))
M_zeros = np.zeros((GRID_HEIGHT, GRID_WIDTH, len(H), NUM_ACTIONS))
for _terminal in [GOAL, *OBSTACLES]:
    M_zeros[_terminal[0], _terminal[1], H >= 0, :] = 1


## eta sampling X, pretrained parameters X
set_seed(SEED)
Q_cvarXX, MXX = PCVaR_Q_learning_eta(np.copy(Q_cvar_zeros), np.copy(M_zeros))
## eta sampling O, pretrained parameters X
set_seed(SEED)
Q_cvarOX, MOX = PCVaR_Q_learning_etasample(np.copy(Q_cvar_zeros), np.copy(M_zeros))
## eta sampling X, pretrained parameters O
set_seed(SEED)
Q_cvarXO, MXO = PCVaR_Q_learning_eta(np.copy(Pre_Q_cvar), np.copy(Pre_M))
## eta sampling O, pretrained parameters O
set_seed(SEED)
Q_cvarOO, MOO = PCVaR_Q_learning_etasample(np.copy(Pre_Q_cvar), np.copy(Pre_M))


# In[ ]:




